import torch
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models

print("Available models:", available_models())
# Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']

# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CN-CLIP ViT-B-16 checkpoint plus its matching image-preprocessing
# transform; weights are cached under ./app/assets (downloaded if absent).
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./app/assets')
# Inference only: switch off dropout / batch-norm training behavior.
model.eval()


def get_object_vector(arg_image, arg_text):
    """Encode an image and a text batch with CN-CLIP and return both vectors.

    Args:
        arg_image: preprocessed image tensor (batch first, already on `device`
            via `preprocess(...).unsqueeze(0)` — assumed, confirm at caller).
        arg_text: tokenized text tensor from `clip.tokenize(...)` on `device`.

    Returns:
        Tuple ``(image_emb, text_emb)`` — the L2-normalized feature vectors
        converted to nested Python lists via ``Tensor.tolist()``.
    """
    with torch.no_grad():
        text_features = model.encode_text(arg_text)
        print("文特征值：", text_features)
        # Normalize the features; downstream tasks should consume the
        # normalized image/text embeddings.
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # BUG FIX: the original reused one local `emb` for both vectors;
        # keep them in separate names so both survive to the return.
        text_emb = text_features.tolist()
        print("文向量为：", text_emb)

        image_features = model.encode_image(arg_image)
        print("图特征值：", image_features)
        # Same normalization for the image side.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        image_emb = image_features.tolist()
        print("图向量为：", image_emb)

    # BUG FIX: despite its "get_" name the original returned None and only
    # printed; return the embeddings so callers can actually use them.
    # Backward-compatible: existing callers that ignore the result still work.
    return image_emb, text_emb


if __name__ == '__main__':
    # Tokenize the Chinese query for the CLIP text encoder and move it to
    # the same device as the model.
    queries = ["杯子"]
    text = clip.tokenize(queries).to(device)

    image_name = "data/test1/杯子8.5M-3024x4032.jpg"
    print(image_name)
    # Load the image, apply the model's preprocessing transform, and add a
    # leading batch dimension before sending it to the device.
    raw_image = Image.open(image_name)
    image = preprocess(raw_image).unsqueeze(0).to(device)

    get_object_vector(image, text)
